package com.stars.easyms.schedule.repository;

import com.stars.easyms.base.bean.LazyLoadBean;
import com.stars.easyms.schedule.bean.BatchResult;
import com.stars.easyms.schedule.enums.DbType;
import com.stars.easyms.schedule.exception.DistributedScheduleRuntimeException;
import com.stars.easyms.schedule.factory.DbScheduleSqlSessionFactory;
import com.stars.easyms.schedule.util.DistributedSchedulePackageUtil;
import org.apache.ibatis.executor.ErrorContext;
import org.apache.ibatis.executor.SimpleExecutor;
import org.apache.ibatis.executor.keygen.KeyGenerator;
import org.apache.ibatis.logging.jdbc.ConnectionLogger;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.Configuration;
import org.mybatis.spring.transaction.SpringManagedTransaction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jdbc.datasource.ConnectionHolder;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import org.springframework.transaction.support.TransactionSynchronizationManager;

import javax.sql.DataSource;
import java.sql.BatchUpdateException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Batch commit DAO.
 *
 * <p>Executes a MyBatis mapped statement for a list of rows using JDBC batching:
 * one {@link PreparedStatement} per distinct SQL text, with a fresh statement
 * opened every {@link #PER_COMMIT_SIZE} rows. The whole batch runs inside a
 * dedicated Spring-managed transaction ({@code PROPAGATION_REQUIRES_NEW}); on a
 * partial failure the rows are split into success/failure sets according to the
 * database vendor's batch-error semantics, and rows after the failing one are
 * retried.
 *
 * <p>Not thread-safe per call, but per-thread state (the transaction status) is
 * kept in a {@link ThreadLocal}, so different threads may batch concurrently.
 *
 * @author guoguifang
 */
public class DistributedScheduleBatchCommitDAO {

    private static final Logger logger = LoggerFactory.getLogger(DistributedScheduleBatchCommitDAO.class);

    /** Lazily instantiated singleton holder. */
    private static final LazyLoadBean<DistributedScheduleBatchCommitDAO> DISTRIBUTED_SCHEDULE_BATCH_COMMIT_DAO_LAZY_LOAD_BEAN
            = new LazyLoadBean<>(DistributedScheduleBatchCommitDAO::new);

    /** Lazily resolved application {@link DataSource}. */
    private final LazyLoadBean<DataSource> dataSource = new LazyLoadBean<>(DataSource.class);

    /** Created on first use via double-checked locking in {@link #getConnection}. */
    private volatile DataSourceTransactionManager transactionManager;

    /** Transaction of the batch currently running on this thread; removed in the finally of {@link #batchCommit(String, Map)}. */
    private final ThreadLocal<TransactionStatus> transactionStatus = new ThreadLocal<>();

    /** Number of rows accumulated in one PreparedStatement before a new one is opened. */
    private static final int PER_COMMIT_SIZE = 5000;

    /**
     * Batch-inserts the given rows with the mapped statement identified by {@code sqlId}.
     *
     * @param sqlId MyBatis statement id (short form or fully qualified)
     * @param list  rows to insert; may be {@code null} or empty
     * @return result holding the rows that succeeded and the rows that failed
     */
    public <T> BatchResult<T> batchInsert(String sqlId, List<T> list) {
        return this.batchCommit(sqlId, convertOrdinalMap(list));
    }

    /**
     * Batch-updates the given rows with the mapped statement identified by {@code sqlId}.
     *
     * @param sqlId MyBatis statement id (short form or fully qualified)
     * @param list  rows to update; may be {@code null} or empty
     * @return result holding the rows that succeeded and the rows that failed
     */
    public <T> BatchResult<T> batchUpdate(String sqlId, List<T> list) {
        return this.batchCommit(sqlId, convertOrdinalMap(list));
    }

    /**
     * Batch-deletes the given rows with the mapped statement identified by {@code sqlId}.
     *
     * @param sqlId MyBatis statement id (short form or fully qualified)
     * @param list  rows to delete; may be {@code null} or empty
     * @return result holding the rows that succeeded and the rows that failed
     */
    public <T> BatchResult<T> batchDelete(String sqlId, List<T> list) {
        return this.batchCommit(sqlId, convertOrdinalMap(list));
    }

    /**
     * Runs the batch in a new transaction: commits on success, rolls back on any
     * unexpected exception and then reports every row as failed (since the whole
     * transaction was undone).
     */
    private <T> BatchResult<T> batchCommit(String sqlId, Map<Integer, T> ordinalMap) {
        BatchResult<T> batchResult = new BatchResult<>();
        if (ordinalMap != null) {
            try {
                batchResult = batchCommit(new BatchCommitConfig(sqlId), ordinalMap, false);
                commitTransaction();
            } catch (Exception e) {
                logger.error("Batch processing data failed!", e);
                rollbackTransaction();
                // The transaction was rolled back, so partial successes no longer hold:
                // everything counts as failed.
                batchResult.getSuccessDatas().clear();
                batchResult.getFailDatas().clear();
                batchResult.addFailDatas(ordinalMap);
            } finally {
                transactionStatus.remove();
            }
        }
        return batchResult;
    }

    /**
     * Prepares and executes the JDBC batches for the given ordinal-indexed rows.
     *
     * @param batchCommitConfig resolved statement/connection context for this batch
     * @param ordinalMap        rows keyed by their original list position
     * @param isFailRetry       true when re-executing the remainder after a failure
     *                          (selectKey generation is skipped in that case)
     * @return per-row success/failure outcome
     * @throws SQLException if a statement cannot be prepared
     */
    @SuppressWarnings("unchecked")
    private <T> BatchResult<T> batchCommit(BatchCommitConfig batchCommitConfig, Map<Integer, T> ordinalMap, boolean isFailRetry) throws SQLException {
        BatchResult<T> batchResult = new BatchResult<>();
        Map<Integer, T> failRetryMap = new TreeMap<>();
        Map<String, List<PreparedStatementHolder>> preparedStatementMap = batchCommitConfig.getPreparedStatementMap();
        // Prepare the batch data.
        for (Map.Entry<Integer, T> entry : ordinalMap.entrySet()) {
            Integer index = entry.getKey();
            T t = entry.getValue();
            // On a failure retry the selectKey value was already generated, so skip it.
            if (!isFailRetry) {
                // If the statement has a selectKey, run it on a separate connection
                // and populate the generated value into the row object.
                ErrorContext.instance().store();
                SpringManagedTransaction springManagedTransaction = new SpringManagedTransaction(dataSource.getNonNullBean());
                batchCommitConfig.keyGenerator.processBefore(new SimpleExecutor(
                        batchCommitConfig.configuration, springManagedTransaction), batchCommitConfig.mappedStatement, null, t);
                ErrorContext.instance().recall();
                springManagedTransaction.close();
            }
            // One PreparedStatement per distinct SQL text; rows sharing the SQL are batched together.
            BoundSql boundSql = batchCommitConfig.mappedStatement.getBoundSql(t);
            String sql = boundSql.getSql();
            List<PreparedStatementHolder> preparedStatementList = preparedStatementMap.get(sql);
            PreparedStatementHolder preparedStatementHolder = null;
            if (preparedStatementList == null) {
                preparedStatementList = new ArrayList<>();
                preparedStatementHolder = new PreparedStatementHolder(batchCommitConfig.connection.prepareStatement(sql));
                preparedStatementList.add(preparedStatementHolder);
                preparedStatementMap.put(sql, preparedStatementList);
            }
            if (preparedStatementHolder == null) {
                // Always append to the most recently opened statement for this SQL.
                preparedStatementHolder = preparedStatementList.get(preparedStatementList.size() - 1);
            }
            DefaultParameterHandler defaultParameterHandler = new DefaultParameterHandler(batchCommitConfig.mappedStatement, t, boundSql);
            defaultParameterHandler.setParameters(preparedStatementHolder.preparedStatement);
            preparedStatementHolder.preparedStatement.addBatch();
            preparedStatementHolder.subOrdinalMap.put(index, t);
            // Cap each statement at PER_COMMIT_SIZE rows; open a fresh one for the overflow.
            if (preparedStatementHolder.count.incrementAndGet() == PER_COMMIT_SIZE) {
                preparedStatementList.add(new PreparedStatementHolder(batchCommitConfig.connection.prepareStatement(sql)));
            }
        }
        // Execute the batches.
        for (List<PreparedStatementHolder> preparedStatementHolderList : preparedStatementMap.values()) {
            for (PreparedStatementHolder preparedStatementHolder : preparedStatementHolderList) {
                try {
                    preparedStatementHolder.preparedStatement.executeBatch();
                    batchResult.addSuccessDatas(preparedStatementHolder.subOrdinalMap);
                } catch (BatchUpdateException batchUpdateException) {
                    int[] updateCounts = batchUpdateException.getUpdateCounts();
                    if (updateCounts != null) {
                        int subOrdinalMapSize = preparedStatementHolder.subOrdinalMap.size();
                        if (DbType.ORACLE == DistributedSchedulePackageUtil.getDbType()) {
                            // Oracle stops at the first failing statement, so updateCounts.length
                            // is the number of rows that succeeded before the failure.
                            int successLength = updateCounts.length;
                            batchResult.addSuccessDatas(subMap(preparedStatementHolder.subOrdinalMap, 0, successLength));
                            if (subOrdinalMapSize > successLength) {
                                Map<Integer, T> subMap = subMap(preparedStatementHolder.subOrdinalMap, successLength, null);
                                logger.error("Batch processing data failed,sqlId：{},fail data：{}",
                                        batchCommitConfig.sqlId, subMap.values(), batchUpdateException);
                                batchResult.addFailDatas(subMap);
                                // Retry everything after the failing row.
                                if (successLength + 1 < subOrdinalMapSize) {
                                    failRetryMap.putAll(subMap(preparedStatementHolder.subOrdinalMap, successLength + 1, subOrdinalMapSize));
                                }
                            }
                        } else if (DbType.MYSQL == DistributedSchedulePackageUtil.getDbType()) {
                            // With MySQL rewriteBatchedStatements the failed rows cannot be identified
                            // precisely, so on error fall back to single-row execution when allowed.
                            Map<Integer, T> failMap = new TreeMap<>();
                            for (int i = 0; i < updateCounts.length; i++) {
                                Map<Integer, T> subMap = subMap(preparedStatementHolder.subOrdinalMap, i, null);
                                if (updateCounts[i] == -3) {
                                    if (!DistributedSchedulePackageUtil.isAllowBatch()) {
                                        failMap.putAll(subMap);
                                        batchResult.addFailDatas(subMap);
                                    } else if (!singleCommit(batchCommitConfig, subMap.values().toArray()[0])) {
                                        // To cut down round trips, once a row fails re-batch the
                                        // remaining unexecuted rows instead of continuing one by one.
                                        batchResult.addFailDatas(subMap);
                                        failRetryMap.putAll(subMap(preparedStatementHolder.subOrdinalMap, i + 1, subOrdinalMapSize));
                                        break;
                                    } else {
                                        batchResult.addSuccessDatas(subMap);
                                    }
                                } else {
                                    batchResult.addSuccessDatas(subMap);
                                }
                            }
                            if (!failMap.isEmpty()) {
                                logger.error("Batch processing data failed,sqlId：{},fail data：{}",
                                        batchCommitConfig.sqlId, failMap.values(), batchUpdateException);
                            }
                        } else {
                            // Unknown vendor: cannot attribute the failure, mark the whole sub-batch failed.
                            logger.error("Batch processing data failed,sqlId：{},fail data：{}",
                                    batchCommitConfig.sqlId, preparedStatementHolder.subOrdinalMap.values(), batchUpdateException);
                            batchResult.addFailDatas(preparedStatementHolder.subOrdinalMap);
                        }
                    } else {
                        logger.error("Batch processing data failed,sqlId：{},fail data：{}",
                                batchCommitConfig.sqlId, preparedStatementHolder.subOrdinalMap.values(), batchUpdateException);
                        batchResult.addFailDatas(preparedStatementHolder.subOrdinalMap);
                    }
                } catch (Exception exception) {
                    logger.error("Batch processing data failed,sqlId：{},fail data：{}",
                            batchCommitConfig.sqlId, preparedStatementHolder.subOrdinalMap.values(), exception);
                    batchResult.addFailDatas(preparedStatementHolder.subOrdinalMap);
                } finally {
                    // Each statement is used exactly once; close it here so retries
                    // (which clear the holder map and prepare new statements) don't leak.
                    closeQuietly(preparedStatementHolder.preparedStatement);
                }
            }
        }
        // Not an all-or-nothing transaction: re-execute the rows that followed a failure.
        if (!failRetryMap.isEmpty()) {
            BatchResult<T> subBatchResult = batchCommit(batchCommitConfig, failRetryMap, true);
            batchResult.addSuccessDatas(subBatchResult.getSuccessOrdinalMap());
            batchResult.addFailDatas(subBatchResult.getFailOrdinalMap());
        }
        return batchResult;
    }

    /**
     * Executes the mapped statement for a single row on the batch connection.
     *
     * @return true if the row was written, false if it failed (the failure is logged)
     */
    private <T> boolean singleCommit(BatchCommitConfig batchCommitConfig, T t) {
        try {
            BoundSql boundSql = batchCommitConfig.mappedStatement.getBoundSql(t);
            String sql = boundSql.getSql();
            // try-with-resources: the single-row statement must not leak on the shared connection.
            try (PreparedStatement preparedStatement = batchCommitConfig.connection.prepareStatement(sql)) {
                DefaultParameterHandler defaultParameterHandler = new DefaultParameterHandler(batchCommitConfig.mappedStatement, t, boundSql);
                defaultParameterHandler.setParameters(preparedStatement);
                preparedStatement.executeUpdate();
            }
        } catch (Exception e) {
            logger.error("Single data execution failure after batch processing data fail,sqlId：{},fail data：{}",
                    batchCommitConfig.sqlId, t, e);
            return false;
        }
        return true;
    }

    /**
     * Closes a statement, logging (not propagating) any close failure so the
     * already-computed batch result is unaffected.
     */
    private static void closeQuietly(PreparedStatement preparedStatement) {
        try {
            preparedStatement.close();
        } catch (SQLException e) {
            logger.warn("Close PreparedStatement failure!", e);
        }
    }

    /**
     * Converts a list into a map keyed by each element's list position, so the
     * original ordering survives the success/failure partitioning.
     *
     * @return the ordinal map, or {@code null} when the list is null or empty
     */
    private <T> Map<Integer, T> convertOrdinalMap(List<T> list) {
        int listSize;
        if (list != null && (listSize = list.size()) > 0) {
            Map<Integer, T> ordinalMap = new TreeMap<>();
            for (int i = 0; i < listSize; i++) {
                ordinalMap.put(i, list.get(i));
            }
            return ordinalMap;
        }
        return null;
    }

    /**
     * Extracts entries by iteration position (not by key).
     *
     * <p>With {@code end != null} returns positions {@code [start, end)};
     * with {@code end == null} returns only the single entry at position
     * {@code start}. Returns an empty map for a null start, a non-positive end,
     * or {@code start == end}.
     */
    private <T> Map<Integer, T> subMap(Map<Integer, T> ordinalMap, Integer start, Integer end) {
        Map<Integer, T> subOrdinalMap = new TreeMap<>();
        if (start == null || (end != null && (end <= 0 || start.equals(end)))) {
            return subOrdinalMap;
        }
        int count = 0;
        for (Map.Entry<Integer, T> entry : ordinalMap.entrySet()) {
            if (end != null) {
                if (count >= start && count < end) {
                    subOrdinalMap.put(entry.getKey(), entry.getValue());
                }
            } else if (count == start) {
                subOrdinalMap.put(entry.getKey(), entry.getValue());
                break;
            }
            count++;
        }
        return subOrdinalMap;
    }

    /**
     * Per-batch context: the resolved MyBatis statement, its key generator,
     * the transaction-bound connection, and the SQL-to-statement holder map.
     */
    private class BatchCommitConfig {

        private final Connection connection;

        private final String sqlId;

        private final Configuration configuration;

        private final MappedStatement mappedStatement;

        private final KeyGenerator keyGenerator;

        private final Map<String, List<PreparedStatementHolder>> preparedStatementMap;

        /** Returns the holder map, cleared for reuse (previous statements are closed after execution). */
        private Map<String, List<PreparedStatementHolder>> getPreparedStatementMap() {
            this.preparedStatementMap.clear();
            return this.preparedStatementMap;
        }

        private BatchCommitConfig(String sqlId) {
            this.configuration = DbScheduleSqlSessionFactory.getSqlSessionFactory().getConfiguration();
            // Prefer the fully qualified id; fall back to the short id if not registered.
            String fullSqlId = DistributedSchedulePackageUtil.getRepositoryClassName() + "." + sqlId;
            MappedStatement localMappedStatement = this.configuration.getMappedStatement(fullSqlId);
            if (localMappedStatement == null) {
                localMappedStatement = this.configuration.getMappedStatement(sqlId);
                this.sqlId = sqlId;
            } else {
                this.sqlId = fullSqlId;
            }
            this.mappedStatement = localMappedStatement;
            this.keyGenerator = localMappedStatement.getKeyGenerator();
            this.preparedStatementMap = new HashMap<>(32);
            this.connection = getConnection(localMappedStatement);
        }
    }

    /**
     * One PreparedStatement together with the rows (by original ordinal) that
     * were added to its batch, and the running row count used to cap the batch
     * at {@link #PER_COMMIT_SIZE}.
     */
    private static class PreparedStatementHolder<T> {

        private final PreparedStatement preparedStatement;

        private final Map<Integer, T> subOrdinalMap;

        private final AtomicLong count;

        private PreparedStatementHolder(PreparedStatement preparedStatement) {
            this.preparedStatement = preparedStatement;
            this.subOrdinalMap = new TreeMap<>();
            this.count = new AtomicLong(0);
        }
    }

    /**
     * Opens a new REQUIRES_NEW transaction on this thread and returns its bound
     * connection (wrapped in a MyBatis ConnectionLogger when statement debug
     * logging is enabled).
     *
     * @throws DistributedScheduleRuntimeException if no connection is bound after
     *         the transaction was started
     */
    private Connection getConnection(MappedStatement mappedStatement) {
        if (this.transactionManager == null) {
            synchronized (this) {
                if (this.transactionManager == null) {
                    this.transactionManager = new DataSourceTransactionManager(dataSource.getNonNullBean());
                }
            }
        }
        this.transactionStatus.set(this.transactionManager.getTransaction(new DefaultTransactionDefinition(TransactionDefinition.PROPAGATION_REQUIRES_NEW)));
        ConnectionHolder connectionHolder = (ConnectionHolder) TransactionSynchronizationManager.getResource(dataSource.getNonNullBean());
        Connection connection;
        if (connectionHolder != null) {
            connection = connectionHolder.getConnection();
        } else {
            throw new DistributedScheduleRuntimeException("Get datasource connection failure!");
        }
        return mappedStatement.getStatementLog().isDebugEnabled() ? ConnectionLogger.newInstance(connection, mappedStatement.getStatementLog(), 0) : connection;
    }

    /**
     * Commits this thread's transaction (or rolls it back if it was marked
     * rollback-only); no-op when already completed.
     */
    private void commitTransaction() {
        final TransactionStatus localTransactionStatus = this.transactionStatus.get();
        if (localTransactionStatus != null && !localTransactionStatus.isCompleted()) {
            if (localTransactionStatus.isRollbackOnly()) {
                this.transactionManager.rollback(localTransactionStatus);
            } else {
                this.transactionManager.commit(localTransactionStatus);
            }
        }
    }

    /**
     * Rolls back this thread's transaction; no-op when already completed.
     */
    private void rollbackTransaction() {
        final TransactionStatus localTransactionStatus = this.transactionStatus.get();
        if (localTransactionStatus != null && !localTransactionStatus.isCompleted()) {
            this.transactionManager.rollback(localTransactionStatus);
        }
    }

    /**
     * @return the lazily created singleton instance
     */
    public static DistributedScheduleBatchCommitDAO getInstance() {
        return DISTRIBUTED_SCHEDULE_BATCH_COMMIT_DAO_LAZY_LOAD_BEAN.getBean();
    }

    /** Singleton: instances are created only through the lazy-load bean. */
    private DistributedScheduleBatchCommitDAO() {
    }

}
